{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# OOB Errors for Random Forests\n\nThe ``RandomForestClassifier`` is trained using *bootstrap aggregation*, where\neach new tree is fit on a bootstrap sample of the training observations\n$z_i = (x_i, y_i)$. The *out-of-bag* (OOB) error is the average error for\neach $z_i$, calculated using predictions from only those trees whose\nbootstrap sample does not contain $z_i$. This allows the\n``RandomForestClassifier`` to be fit and validated in the same training run [1].\n\nThe example below demonstrates how the OOB error can be measured as each\nnew tree is added during training. The resulting plot lets a practitioner\nestimate a value of ``n_estimators`` beyond which the error stabilizes.\n\n[1] T. Hastie, R. Tibshirani and J. Friedman, *The Elements of Statistical\nLearning*, 2nd ed., pp. 592-593, Springer, 2009.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n\nfrom collections import OrderedDict\nfrom sklearn.datasets import make_classification\nfrom sklearn.ensemble import RandomForestClassifier\n\n# Author: Kian Ho <hui.kian.ho@gmail.com>\n#         Gilles Louppe <g.louppe@gmail.com>\n#         Andreas Mueller <amueller@ais.uni-bonn.de>\n#\n# License: BSD 3 Clause\n\nRANDOM_STATE = 123\n\n# Generate a binary classification dataset.\nX, y = make_classification(n_samples=500, n_features=25,\n                           n_clusters_per_class=1, n_informative=15,\n                           random_state=RANDOM_STATE)\n\n# NOTE: Setting the `warm_start` construction parameter to `True` disables\n# support for parallelized ensembles but is necessary for tracking the OOB\n# error trajectory during training.\nensemble_clfs = [\n    (\"RandomForestClassifier, max_features='sqrt'\",\n        RandomForestClassifier(warm_start=True, oob_score=True,\n                               max_features=\"sqrt\",\n                               random_state=RANDOM_STATE)),\n    (\"RandomForestClassifier, max_features='log2'\",\n        RandomForestClassifier(warm_start=True, max_features='log2',\n                               oob_score=True,\n                               random_state=RANDOM_STATE)),\n    (\"RandomForestClassifier, max_features=None\",\n        RandomForestClassifier(warm_start=True, max_features=None,\n                               oob_score=True,\n                               random_state=RANDOM_STATE))\n]\n\n# Map a classifier name to a list of (<n_estimators>, <error rate>) pairs.\nerror_rate = OrderedDict((label, []) for label, _ in ensemble_clfs)\n\n# Range of `n_estimators` values to explore.\nmin_estimators = 15\nmax_estimators = 175\n\nfor label, clf in ensemble_clfs:\n    for i in range(min_estimators, max_estimators + 1):\n        clf.set_params(n_estimators=i)\n        clf.fit(X, y)\n\n        # Record the OOB error for each `n_estimators=i` setting.\n        oob_error = 1 - clf.oob_score_\n        error_rate[label].append((i, oob_error))\n\n# Generate the \"OOB error rate\" vs. \"n_estimators\" plot.\nfor label, clf_err in error_rate.items():\n    xs, ys = zip(*clf_err)\n    plt.plot(xs, ys, label=label)\n\nplt.xlim(min_estimators, max_estimators)\nplt.xlabel(\"n_estimators\")\nplt.ylabel(\"OOB error rate\")\nplt.legend(loc=\"upper right\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}